#!/usr/bin/env python3
import argparse, re, csv, math
from pathlib import Path
from collections import Counter

# SciPy is an optional dependency. SCIPY_OK records whether the two tests
# used by paired_stats() could be imported; when it is False, paired_stats()
# returns None placeholders and main() prints a "skipping" notice instead.
# Any exception raised during the import (not only ImportError) disables the
# statistical tests rather than aborting the script.
try:
    from scipy.stats import ttest_rel, wilcoxon
    SCIPY_OK = True
except Exception:
    SCIPY_OK = False

# Matches an item header at the start of a line (multiline mode): optional
# leading whitespace, a 1-2 digit number (captured), a period, and one
# whitespace character. Because the number is a capture group, re.split()
# returns the numbers interleaved with the text that follows each header.
ITEM_SPLIT_RE = re.compile(r"(?m)^\s*(\d{1,2})\.\s")

# Conditions analysed by main(); each name is also the filename prefix used
# by build_files().
CONDITION_NAMES = ["BaselineNoJSON", "BaselineWithJSON", "HighRigor", "PushbackStrong"]

def trim_item50_noise(text: str) -> str:
    """Drop a trailing "please note ..." paragraph and trailing whitespace.

    The text is cut at the first blank-line-separated paragraph that begins
    with "please note" (case-insensitive, whole word). If no such paragraph
    exists the text is returned with only trailing whitespace removed.
    """
    note = re.search(r"(?ims)\r?\n\r?\n+please note\b", text)
    kept = text if note is None else text[:note.start()]
    return kept.rstrip()

def extract_items(text: str) -> dict[int, str]:
    """Split a numbered-list response into ``{item_number: item_text}``.

    Item headers are located with ITEM_SPLIT_RE; the text following each
    header (stripped) becomes that item's content. If a number occurs more
    than once, the last occurrence wins. Item 50 is additionally passed
    through trim_item50_noise().

    Raises:
        ValueError: if the set of parsed item numbers is not exactly 1..50.
    """
    pieces = ITEM_SPLIT_RE.split(text)
    # pieces[0] is whatever precedes the first header; after that the list
    # alternates captured number, item body, captured number, item body, ...
    numbers = pieces[1::2]
    bodies = pieces[2::2]
    items = {int(number): body.strip() for number, body in zip(numbers, bodies)}

    if 50 in items:
        items[50] = trim_item50_noise(items[50])

    expected_keys = set(range(1, 51))
    if set(items) != expected_keys:
        found = sorted(items.keys())
        raise ValueError("Bad keys: " + str(found))
    return items

def normalize(s: str) -> str:
    """Strip the string and collapse every whitespace run to a single space."""
    trimmed = s.strip()
    collapsed = re.sub(r"\s+", " ", trimmed)
    return collapsed

def word_count(s: str) -> int:
    """Return the number of whitespace-separated tokens in ``s``."""
    tokens = s.split()
    return len(tokens)

def mean(xs):
    """Return the arithmetic mean of ``xs``, or NaN when ``xs`` is empty."""
    if not xs:
        return float("nan")
    return sum(xs) / len(xs)

def stdev(xs):
    """Return the sample standard deviation (n - 1 denominator) of ``xs``.

    Returns NaN when fewer than two values are supplied.
    """
    n = len(xs)
    if n < 2:
        return float("nan")
    center = mean(xs)
    squared_deviations = sum((x - center) ** 2 for x in xs)
    return math.sqrt(squared_deviations / (n - 1))

def cohens_dz(diffs):
    """Return Cohen's d_z for paired differences: mean(diffs) / stdev(diffs).

    Returns NaN when ``diffs`` is empty or when the standard deviation is
    zero or undefined (NaN).
    """
    spread = stdev(diffs)
    usable = bool(diffs) and spread != 0 and not math.isnan(spread)
    if usable:
        return mean(diffs) / spread
    return float("nan")

def load_runs(base_dir: Path, filenames: list[str]) -> list[dict[int, str]]:
    """Read and parse each run file under ``base_dir``, preserving order.

    Files are decoded as UTF-8 with undecodable bytes dropped, then parsed
    with extract_items().

    Raises:
        FileNotFoundError: as soon as one of the named files does not exist.
    """
    def _parse_one(fn):
        p = base_dir / fn
        if not p.exists():
            raise FileNotFoundError(f"Missing file: {p}")
        raw = p.read_text(encoding="utf-8", errors="ignore")
        return extract_items(raw)

    return [_parse_one(fn) for fn in filenames]

def within_identical_items(runs):
    """Count items 1..50 whose normalized text is the same in every run."""
    return sum(
        1
        for i in range(1, 51)
        if len({normalize(r[i]) for r in runs}) == 1
    )

def consensus_text(runs, i):
    """Return the most frequent normalized text of item ``i`` across runs."""
    normalized = [normalize(r[i]) for r in runs]
    ranked = Counter(normalized).most_common(1)
    return ranked[0][0]

def between_exact_items(cpu_runs, gpu_runs):
    """Count items 1..50 whose CPU and GPU consensus texts are identical."""
    return sum(
        1
        for i in range(1, 51)
        if consensus_text(cpu_runs, i) == consensus_text(gpu_runs, i)
    )

def mean_wordcount_per_item(runs):
    """Return, for items 1..50 in order, the mean word count across runs."""
    return [
        mean([word_count(r[i]) for r in runs])
        for i in range(1, 51)
    ]

def paired_stats(cpu_vals, gpu_vals, alternative="greater"):
    """Run a paired t-test and a Wilcoxon signed-rank test on CPU vs GPU values.

    Args:
        cpu_vals: per-item values for the CPU runs.
        gpu_vals: per-item values for the GPU runs, paired with ``cpu_vals``.
        alternative: alternative hypothesis passed to ``wilcoxon`` for the
            differences ``cpu - gpu`` ("greater", "less" or "two-sided").
            The t-test is called without ``alternative``, so SciPy's default
            applies to it.

    Returns:
        ``(t_stat, t_p, w_stat, w_p)``. All four are ``None`` when SciPy is
        unavailable. Degenerate inputs are handled without calling
        ``wilcoxon``:

        * all differences zero (or no data): ``(0.0, 1.0, 0.0, 1.0)``;
        * every difference equal to the same non-zero value: the t-test
          result is reported as computed, ``w_stat`` is the sentinel
          ``inf`` (the test is not actually run), and ``w_p`` is 0.0 when
          the shift's direction agrees with ``alternative`` and 1.0 when
          it contradicts it.
    """
    if not SCIPY_OK:
        return (None, None, None, None)

    diffs = [c - g for c, g in zip(cpu_vals, gpu_vals)]

    # Degenerate cases -> avoid Wilcoxon z/SE warnings
    if len(set(diffs)) <= 1:
        # All equal (also covers empty input): nothing to test.
        if all(d == 0 for d in diffs):
            return (0.0, 1.0, 0.0, 1.0)

        # Constant non-zero shift. Previously p was hard-coded to 0.0, which
        # reported a constant shift in the *opposite* direction as maximal
        # evidence for the one-sided alternative. Make it direction-aware.
        shift = diffs[0]
        if alternative == "greater":
            supports_alternative = shift > 0
        elif alternative == "less":
            supports_alternative = shift < 0
        else:
            # Two-sided: any non-zero constant shift counts as evidence.
            supports_alternative = True
        w_p = 0.0 if supports_alternative else 1.0

        t_res = ttest_rel(cpu_vals, gpu_vals)
        return (float(t_res.statistic), float(t_res.pvalue), float("inf"), w_p)

    t_res = ttest_rel(cpu_vals, gpu_vals)
    w_res = wilcoxon(cpu_vals, gpu_vals, alternative=alternative)
    return (float(t_res.statistic), float(t_res.pvalue), float(w_res.statistic), float(w_res.pvalue))

def build_files(cond, hw, tags):
    """Return the run filenames ``{cond}_{hw}_T0_{tag}.txt``, one per tag."""
    names = []
    for t in tags:
        names.append(f"{cond}_{hw}_T0_{t}.txt")
    return names

def analyze_condition(base_dir, cond, tags):
    """Compute the CPU-vs-GPU comparison metrics for one condition.

    Loads the CPU and GPU run files for ``cond`` (one per tag), then reports
    within-hardware determinism, between-hardware exact matches, and paired
    word-count statistics (CPU minus GPU) as a flat dict whose key order
    defines the CSV column order used by main().
    """
    cpu_runs = load_runs(base_dir, build_files(cond, "CPU", tags))
    gpu_runs = load_runs(base_dir, build_files(cond, "GPU", tags))

    # Per-item mean word counts and their paired differences.
    cpu_wc = mean_wordcount_per_item(cpu_runs)
    gpu_wc = mean_wordcount_per_item(gpu_runs)
    diffs = [c - g for c, g in zip(cpu_wc, gpu_wc)]

    matching_items = between_exact_items(cpu_runs, gpu_runs)
    t_stat, t_p, w_stat, w_p = paired_stats(cpu_wc, gpu_wc, alternative="greater")

    summary = {}
    summary["condition"] = cond
    summary["cpu_within_identical_items"] = within_identical_items(cpu_runs)
    summary["gpu_within_identical_items"] = within_identical_items(gpu_runs)
    summary["between_exact_items"] = matching_items
    summary["between_diff_items"] = 50 - matching_items
    summary["mean_wordcount_cpu"] = mean(cpu_wc)
    summary["mean_wordcount_gpu"] = mean(gpu_wc)
    summary["mean_wordcount_diff_cpu_minus_gpu"] = mean(diffs)
    summary["sd_wordcount_diff_cpu_minus_gpu"] = stdev(diffs)
    summary["cohens_dz_wordcount_cpu_minus_gpu"] = cohens_dz(diffs)
    summary["paired_t_stat"] = t_stat
    summary["paired_t_p"] = t_p
    summary["wilcoxon_stat"] = w_stat
    summary["wilcoxon_p_one_sided_cpu_gt_gpu"] = w_p
    return summary

def _print_condition_summary(r):
    """Print the human-readable summary block for one condition's results."""
    print(f"[{r['condition']}]")
    print(f"  Within determinism (items identical across 3 runs): CPU {r['cpu_within_identical_items']}/50, GPU {r['gpu_within_identical_items']}/50")
    print(f"  Between CPU/GPU: exact-match items {r['between_exact_items']}/50  (diff {r['between_diff_items']}/50)")
    print(f"  Wordcount mean: CPU {r['mean_wordcount_cpu']:.2f}, GPU {r['mean_wordcount_gpu']:.2f}")
    print(f"  Wordcount diff (CPU-GPU): mean {r['mean_wordcount_diff_cpu_minus_gpu']:.2f}, SD {r['sd_wordcount_diff_cpu_minus_gpu']:.2f}")
    print(f"  Cohen's dz (paired; CPU-GPU): {r['cohens_dz_wordcount_cpu_minus_gpu']:.3f}")
    if not SCIPY_OK:
        print("  SciPy not available: skipping t-test/Wilcoxon.")
    else:
        print(f"  Paired t-test: t={r['paired_t_stat']:.3f}, p={r['paired_t_p']:.6g}")
        print(f"  Wilcoxon (one-sided CPU>GPU): W={r['wilcoxon_stat']}, p={r['wilcoxon_p_one_sided_cpu_gt_gpu']}")
    print()


def main():
    """Command-line entry point.

    Analyses every condition in CONDITION_NAMES for the selected run group,
    prints a per-condition summary, and optionally writes the results to a
    CSV file when --out_csv is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_dir", required=True)
    parser.add_argument("--group", choices=["Warm", "Alt"], default="Warm")
    parser.add_argument("--out_csv", default="")
    args = parser.parse_args()

    # Each group maps to a fixed set of three run tags used in the filenames.
    if args.group == "Warm":
        tags = ["W2", "W3", "W4"]
    else:
        tags = ["A1", "A2", "A3"]
    base_dir = Path(args.base_dir)

    results = []
    for cond in CONDITION_NAMES:
        results.append(analyze_condition(base_dir, cond, tags))

    print(f"\n=== CPU vs GPU Summary (per condition) [{args.group}] ===\n")
    for r in results:
        _print_condition_summary(r)

    if not args.out_csv:
        return

    out_path = Path(args.out_csv)
    with out_path.open("w", newline="", encoding="utf-8") as f:
        # Column order follows the key order of the first result dict.
        writer = csv.DictWriter(f, fieldnames=list(results[0]))
        writer.writeheader()
        writer.writerows(results)
    print(f"Wrote CSV: {out_path}")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()